"""
Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py
"""
from collections import OrderedDict
from typing import Tuple, Union

import hashlib
import os
import urllib
import warnings
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from einops import repeat
from ipdb import set_trace


from modules.ast_models import ASTModel
from modules import resnet_models

_MODELS = {
	"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
	"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
	"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
	"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
	"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
	"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
_PT_NAME = {
	"RN50": "RN50.pt",
	"RN101": "RN101.pt",
	"RN50x4": "RN50x4.pt",
	"RN50x16": "RN50x16.pt",
	"ViT-B/32": "ViT-B-32.pt",
	"ViT-B/16": "ViT-B-16.pt",
}

def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
	os.makedirs(root, exist_ok=True)
	filename = os.path.basename(url)

	expected_sha256 = url.split("/")[-2]
	download_target = os.path.join(root, filename)

	if os.path.exists(download_target) and not os.path.isfile(download_target):
		raise RuntimeError(f"{download_target} exists and is not a regular file")

	if os.path.isfile(download_target):
		if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
			return download_target
		else:
			warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

	with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
		with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
			while True:
				buffer = source.read(8192)
				if not buffer:
					break

				output.write(buffer)
				loop.update(len(buffer))

	if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
		raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

	return download_target

def available_models():
	"""Returns the names of available CLIP models"""
	return list(_MODELS.keys())
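
# Illustrative usage (a minimal sketch, not part of the original module): list the released
# checkpoints and fetch one into the default cache used by `_download` (~/.cache/clip).
#
#   names = available_models()             # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
#   path = _download(_MODELS["ViT-B/32"])  # downloads ViT-B-32.pt and verifies its SHA256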

# =============================

class Bottleneck(nn.Module):
	expansion = 4

	def __init__(self, inplanes, planes, stride=1):
		super().__init__()

		# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
		self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
		self.bn1 = nn.BatchNorm2d(planes)

		self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(planes)

		self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

		self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
		self.bn3 = nn.BatchNorm2d(planes * self.expansion)

		self.relu = nn.ReLU(inplace=True)
		self.downsample = None
		self.stride = stride

		if stride > 1 or inplanes != planes * Bottleneck.expansion:
			# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
			self.downsample = nn.Sequential(OrderedDict([
				("-1", nn.AvgPool2d(stride)),
				("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
				("1", nn.BatchNorm2d(planes * self.expansion))
			]))
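
	# Shape sketch (illustrative, assuming the first block of layer2 in CLIP-RN50, i.e.
	# Bottleneck(256, 128, stride=2)): [N, 256, 56, 56] -> conv1 -> [N, 128, 56, 56]
	# -> conv2 -> [N, 128, 56, 56] -> avgpool(2) -> [N, 128, 28, 28] -> conv3 -> [N, 512, 28, 28];
	# the residual path applies avgpool + 1x1 conv so the addition in forward() stays shape-consistent.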

	def forward(self, x: torch.Tensor):
		identity = x

		out = self.relu(self.bn1(self.conv1(x)))
		out = self.relu(self.bn2(self.conv2(out)))
		out = self.avgpool(out)
		out = self.bn3(self.conv3(out))

		if self.downsample is not None:
			identity = self.downsample(x)

		out += identity
		out = self.relu(out)
		return out


class AttentionPool2d(nn.Module):
	def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
		super().__init__()
		self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
		self.k_proj = nn.Linear(embed_dim, embed_dim)
		self.q_proj = nn.Linear(embed_dim, embed_dim)
		self.v_proj = nn.Linear(embed_dim, embed_dim)
		self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
		self.num_heads = num_heads

	def forward(self, x):
		x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
		x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
		x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
		x, _ = F.multi_head_attention_forward(
			query=x, key=x, value=x,
			embed_dim_to_check=x.shape[-1],
			num_heads=self.num_heads,
			q_proj_weight=self.q_proj.weight,
			k_proj_weight=self.k_proj.weight,
			v_proj_weight=self.v_proj.weight,
			in_proj_weight=None,
			in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
			bias_k=None,
			bias_v=None,
			add_zero_attn=False,
			dropout_p=0,
			out_proj_weight=self.c_proj.weight,
			out_proj_bias=self.c_proj.bias,
			use_separate_proj_weight=True,
			training=self.training,
			need_weights=False
		)

		return x[0]
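
	# Note (a reading of the code, not upstream documentation): the mean-pooled token prepended in
	# forward() acts as a [CLS]-like query, so `x[0]` is that attended token of shape
	# [N, output_dim or embed_dim] and serves as the pooled image feature.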


class ModifiedResNet(nn.Module):
	"""
	A ResNet class that is similar to torchvision's but contains the following changes:
	- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
	- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
	- The final pooling layer is a QKV attention instead of an average pool
	"""

	def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
		super().__init__()
		self.output_dim = output_dim
		self.input_resolution = input_resolution

		# the 3-layer stem
		self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(width // 2)
		self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(width // 2)
		self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
		self.bn3 = nn.BatchNorm2d(width)
		self.avgpool = nn.AvgPool2d(2)
		self.relu = nn.ReLU(inplace=True)

		# residual layers
		self._inplanes = width  # this is a *mutable* variable used during construction
		self.layer1 = self._make_layer(width, layers[0])
		self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
		self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
		self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

		embed_dim = width * 32  # the ResNet feature dimension
		self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

	def _make_layer(self, planes, blocks, stride=1):
		layers = [Bottleneck(self._inplanes, planes, stride)]

		self._inplanes = planes * Bottleneck.expansion
		for _ in range(1, blocks):
			layers.append(Bottleneck(self._inplanes, planes))

		return nn.Sequential(*layers)

	def forward(self, x):
		def stem(x):
			for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
				x = self.relu(bn(conv(x)))
			x = self.avgpool(x)
			return x

		x = x.type(self.conv1.weight.dtype)
		x = stem(x)
		x = self.layer1(x)
		x = self.layer2(x)
		x = self.layer3(x)
		x = self.layer4(x)
		x = self.attnpool(x)

		return x


class LayerNorm(nn.LayerNorm):
	"""Subclass torch's LayerNorm to handle fp16."""

	def forward(self, x: torch.Tensor):
		orig_type = x.dtype
		ret = super().forward(x.type(torch.float32))
		return ret.type(orig_type)


class QuickGELU(nn.Module):
	def forward(self, x: torch.Tensor):
		return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
	def __init__(self, d_model: int, n_head: int, opt=None, attn_mask=None, space_time=False, audio_vis=False, layer_idx=0):
		super().__init__()

		### spatial attn
		self.attn = nn.MultiheadAttention(d_model, n_head)
		self.ln_1 = LayerNorm(d_model)
		self.layer_idx = layer_idx
		self.space_time = space_time
		self.audio_vis = audio_vis
		self.opt = opt
		if self.space_time:
			## temporal attn
			self.temporal_attn = nn.MultiheadAttention(d_model, n_head)
			self.temporal_ln_1 = LayerNorm(d_model)
			self.temporal_fc = nn.Linear(d_model, d_model)

			self.temporal_av_attn = nn.MultiheadAttention(d_model, n_head)
			self.temporal_av_ln_1 = LayerNorm(d_model)
			self.temporal_av_fc = nn.Linear(d_model, d_model)


		if self.audio_vis:
			## audio attn
			self.audio_vis_attn = nn.MultiheadAttention(d_model, n_head)
			self.vis_audio_attn = nn.MultiheadAttention(d_model, n_head)

			self.audio_vis_ln_1 = LayerNorm(d_model)
			self.audio_vis_fc = nn.Linear(d_model, d_model)

			self.vis_audio_ln_1 = LayerNorm(d_model)
			self.vis_audio_fc = nn.Linear(d_model, d_model)

			self.audio_a_attn = nn.MultiheadAttention(d_model, n_head)
			self.audio_a_ln_1 = LayerNorm(d_model)
			self.audio_a_fc = nn.Linear(d_model, d_model)

		### MLP
		self.mlp = nn.Sequential(OrderedDict([
			("c_fc", nn.Linear(d_model, d_model * 4)),
			("gelu", QuickGELU()),
			("c_proj", nn.Linear(d_model * 4, d_model))
		]))
		self.ln_2 = LayerNorm(d_model)
		self.attn_mask = attn_mask

	def attention(self, x: torch.Tensor):
		attn_mask_ = self.attn_mask
		if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
			attn_mask_ = self.attn_mask(x.size(0))   # LND

		attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None

		return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]

	def attention_audio(self, x: torch.Tensor):
		attn_mask_ = self.attn_mask
		if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
			attn_mask_ = self.attn_mask(x.size(0))   # LND

		attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None

		return self.audio_a_attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
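
	# Layout convention of the temporal / audio-visual helpers below, as read from the rearrange
	# patterns (not separately documented): inputs arrive as [n, (b p t l), m], where n is tokens
	# per frame, the flattened batch groups b (batch), p (pair), t = bs (frames) and l = ts from
	# `video_shape`, and m is the embedding dim; temporal attention moves n into the batch so that
	# attention runs over the frame axis t.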

	def temporal_attention_ori(self, x: torch.Tensor, video_shape=None):
		if video_shape is not None:
			b, pair, bs, ts, channel, h, w = video_shape

		attn_mask_ = self.attn_mask
		if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
			attn_mask_ = self.attn_mask(x.size(0))   # LND

		attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None

		if attn_mask_ is not None:
			print(attn_mask_.size())

		x = rearrange(x, 'n (b p t l) m -> t (b p n l) m', b=b, p=pair, t=bs, l=ts)
		x = self.temporal_attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
		x = rearrange(x, 't (b p n l) m -> n (b p t l) m', b=b, p=pair, t=bs, l=ts)
		return x

	def temporal_attention_bk(self, x: torch.Tensor, video_shape=None):
		if video_shape is not None:
			b, pair, bs, ts, channel, h, w = video_shape

		attn_mask_ = self.attn_mask
		if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
			attn_mask_ = self.attn_mask(x.size(0))   # LND

		attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None

		if attn_mask_ is not None:
			print(attn_mask_.size())
		

		x_a_cls = x[:2]
		x = x[2:]


		x = rearrange(x, 'n (b p t l) m -> t (b p n l) m', b=b, p=pair, t=bs, l=ts)
		x_a_cls = rearrange(x_a_cls, 'n (b p t l) m -> t (b p n l) m', b=b, p=pair, t=bs, l=ts)

		x_tmp = torch.cat((x_a_cls,x),dim=1)
		# x = self.temporal_attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
		x_tmp = self.temporal_attn(x_tmp, x_tmp, x_tmp, need_weights=False, attn_mask=attn_mask_)[0]

		x = x_tmp[:,x_a_cls.size(1):]
		x_a_cls = x_tmp[:,:x_a_cls.size(1)]

		x = rearrange(x, 't (b p n l) m -> n (b p t l) m', b=b, p=pair, t=bs, l=ts)
		x_a_cls = rearrange(x_a_cls, 't (b p n l) m -> n (b p t l) m', b=b, p=pair, t=bs, l=ts)
		
		return self.temporal_fc(x), self.temporal_fc(x_a_cls)

	def temporal_attention(self, x: torch.Tensor, video_shape=None, audio=True):
		if video_shape is not None:
			b, pair, bs, ts, channel, h, w = video_shape

		attn_mask_ = self.attn_mask
		if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
			attn_mask_ = self.attn_mask(x.size(0))   # LND

		attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None

		if attn_mask_ is not None:
			print(attn_mask_.size())

		if audio:
			x = rearrange(x, 'n (b p t l) m -> t (b p n l) m', b=b, p=pair, t=bs*2, l=ts)
			x = self.temporal_av_attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
			x = rearrange(x, 't (b p n l) m -> n (b p t l) m', b=b, p=pair, t=bs*2, l=ts)
		else:
			x = rearrange(x, 'n (b p t l) m -> t (b p n l) m', b=b, p=pair, t=bs, l=ts)
			x = self.temporal_attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
			x = rearrange(x, 't (b p n l) m -> n (b p t l) m', b=b, p=pair, t=bs, l=ts)
		return x

	def av_attention(self, x: torch.Tensor, video_shape=None):
		if video_shape is not None:
			b, pair, bs, ts, channel, h, w = video_shape

		attn_mask_ = self.attn_mask
		if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
			attn_mask_ = self.attn_mask(x.size(0))   # LND

		attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None

		if attn_mask_ is not None:
			print(attn_mask_.size())

		assert len(x) == 2, 'error in audio-visual blocks'
		x, x_a = x

		

		if self.opt.yb_dual:


			### new debug ??? att: column normalization
			# att_weight = self.vis_audio_attn(x_a, x, x_a.repeat(50,1,1), need_weights=True, attn_mask=attn_mask_)[1]
			# res_x = torch.bmm(att_weight.permute(0,2,1),x_a.permute(1,0,2)).permute(1,0,2)

			# return self.vis_audio_fc(res_x),\
			# 	self.audio_vis_fc(self.audio_vis_attn(x_a, x, x, need_weights=False, attn_mask=attn_mask_)[0])
			##

			
			return self.vis_audio_fc(self.vis_audio_attn(x, x_a, x_a, need_weights=True, attn_mask=attn_mask_)[0]),\
				self.audio_vis_fc(self.audio_vis_attn(x_a, x, x, need_weights=False, attn_mask=attn_mask_)[0])
		else:
			return self.vis_audio_fc(self.vis_audio_attn(x, x_a, x_a, need_weights=False, attn_mask=attn_mask_)[0]), x_a
		# 
		#


	def forward(self, x_tuple:tuple):
		x, video_frame, video_shape = x_tuple

		
		if len(x) == 2:
			x, x_a = x
			# x_a_cls = x_a[:2]
			# x_a = x_a[2:]

	

		### audio-visual attention
		if self.audio_vis and self.layer_idx < self.opt.yb_start_layer:
			# x_a = x_a + self.attention_audio(self.audio_a_ln_1(x_a))

			### AST !!!!!
			# x_a = self.audio_a_attn(x_a)
			## end
			x_res, x_a_res = self.av_attention([self.vis_audio_ln_1(x), self.audio_vis_ln_1(x_a)], video_shape)

			x = x + x_res
			if self.opt.yb_dual:
				x_a = x_a + x_a_res	


			### self AST  ###
			
			# attn_mask_ = self.attn_mask
			# if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
			#     attn_mask_ = self.attn_mask(x_a.size(0))   # LND

			# attn_mask_ = attn_mask_.to(dtype=x_a.dtype, device=x.device) if attn_mask_ is not None else None

			# if attn_mask_ is not None:
			#     print(attn_mask_.size())


			
			# x_a_tmp = x_a
			# x_a_tmp = self.audio_a_ln_1(x_a_tmp)
			# x_a_tmp = self.audio_a_fc(self.audio_a_attn(x_a_tmp, x_a_tmp, x_a_tmp, need_weights=False, attn_mask=attn_mask_)[0])

			# x_a = x_a + x_a_tmp
			# ### AST end ###
		else:
			return ((x,x_a), video_frame, video_shape)

		if self.space_time:
			x = x + self.temporal_fc(self.temporal_attention_ori(self.temporal_ln_1(x), video_shape))

		### temporal attention
		# if self.space_time:
			
			
		#     b, pair, bs, ts, channel, h, w = video_shape
			
		#     x_tmp = rearrange(x, 'n (b p t l) m -> n b t (p l) m', b=b, p=pair, t=bs, l=ts)
		#     x_a_tmp = rearrange(x_a_cls, 'n (b p t l) m -> n b t (p l) m', b=b, p=pair, t=bs, l=ts)


		#     # x_a_tmp_1 = repeat(x_a_tmp[:1], 'p bs t l d -> (p repeat) bs t l d ', repeat=50)
		#     # x_a_tmp_2 = repeat(x_a_tmp[1:], 'p bs t l d -> (p repeat) bs t l d ', repeat=50)

			

		#     x_tmp_cls = torch.cat(( (x_a_tmp[:1]+x_a_tmp[1:])/2 ,x_tmp[:1]), dim=2)

			
		#     x_tmp_cls = rearrange(x_tmp_cls, 'n b t l m -> n (b t l) m ', b=b, t=bs*2, l=ts) # b:16 pair:1 bs:8 ts:1
		#     x_tmp = rearrange(x_tmp[1:], 'n b t l m -> n (b t l) m ', b=b, t=bs, l=ts)
		#     # 50 16 24 1 768

			
		#     x_tmp_cls = self.temporal_av_fc(self.temporal_attention(self.temporal_av_ln_1(x_tmp_cls), video_shape))
		#     x_tmp_cls = rearrange(x_tmp_cls, 'n (b t l) m -> n b t l m ', b=b, t=bs*2, l=ts)


		#     x_tmp_a_cls = rearrange(x_tmp_cls[:,:,:bs], 'n b t l m -> n (b t l) m ', b=b, t=bs, l=ts)

		#     x_tmp_cls = rearrange(x_tmp_cls[:,:,bs:], 'n b t l m -> n (b t l) m ', b=b, t=bs, l=ts)

		#     x_tmp = self.temporal_fc(self.temporal_attention(self.temporal_ln_1(x_tmp), video_shape, audio=False))
		#     # x_tmp = rearrange(x_tmp, 'n (b t l) m -> n b t l m ', b=b, t=bs, l=ts)
		#     # x_tmp_cls = rearrange(x_tmp_cls[:,:,bs*2:], 'n b t l m -> n (b t l) m ', b=b, t=bs, l=ts)

		#     x = x + torch.cat((x_tmp_cls,x_tmp), dim=0)

		### spatial attention
		x = x + self.attention(self.ln_1(x))

		### MLP
		x = x + self.mlp(self.ln_2(x))


		if self.audio_vis:
			return ((x,x_a), video_frame, video_shape)
		else:
			return (x, video_frame, video_shape)

	


class Transformer(nn.Module):
	def __init__(self, width: int, layers: int, heads: int, opt=None, attn_mask=None, space_time=False, audio_vis=False):
		super().__init__()
		self.width = width
		self.layers = layers
		self.opt = opt
		self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, opt, attn_mask, space_time, audio_vis, layer_idx) for layer_idx in range(layers)])
		self.space_time = space_time
		self.audio_vis = audio_vis

	def forward(self, x: torch.Tensor, video_frame=-1, video_shape=None):
		return self.resblocks((x, video_frame, video_shape))[0]


class VisualTransformer(nn.Module):
	def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
				 linear_patch: str = '2d', opt=None):
		super().__init__()
		self.opt = opt
		self.input_resolution = input_resolution
		self.output_dim = output_dim

		self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

		scale = width ** -0.5
		self.class_embedding = nn.Parameter(scale * torch.randn(width))

		self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))

		self.positional_embedding_tmp = nn.Parameter(torch.zeros(self.opt.max_frames, width, dtype=float))

		# self.positional_audio= nn.Parameter(torch.zeros(217, width, dtype=float))
		self.positional_audio = nn.Parameter(torch.zeros(441, width, dtype=float))
		self.audio_cls = nn.Parameter(scale * torch.zeros(width))
		self.audio_dist = nn.Parameter(scale * torch.zeros(width))


		

		self.mlp_audio = nn.Linear(512, 768)
		self.mlp_audio_k = nn.Linear(768 + 50 * 32, opt.audio_cluster)
		self.mlp_vis = nn.Linear(768, 32)

		self.ln_pre = LayerNorm(width)
		# self.transformer = Transformer(width, layers, heads)
		self.transformer = Transformer(width, layers, heads, opt, space_time=False, audio_vis=self.opt.yb_av)

		self.ln_post = LayerNorm(width)
		self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

		self.conv_audio = nn.Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))

		# self.AST = ASTModel(label_dim=512, fstride=10, tstride=10, input_fdim=128,
		#                           input_tdim=1024, imagenet_pretrain=True,
		#                           audioset_pretrain=True, model_size='base384')
		# print(self.AST.v.blocks[0].attn.qkv.weight)
		# For 3D
		assert linear_patch in ['2d', '3d']
		self.linear_patch = linear_patch
		if self.linear_patch == '3d':
			self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size),
								   stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False)

		# self.resnet= resnet_models.AVENet('7414') 
		# checkpoint = torch.load('../VGGSound/H.pth.tar')
		# self.resnet.load_state_dict(checkpoint['model_state_dict'])
		# 

		# self.resnet= resnet_models.AVENet('7414') 
		# checkpoint = torch.load('../VGGSound/H.pth.tar', map_location=torch.device('cpu'))
		# self.resnet.load_state_dict(checkpoint['model_state_dict'])
		# self.resnet.eval()

	def forward(self, x: torch.Tensor, audio: torch.Tensor, video_frame=-1, video_shape=None, is_train=True):


		if self.linear_patch == '3d':
			assert video_frame != -1
			x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1])
			x_3d = x_3d.permute(0, 2, 1, 3, 4)
			x_3d = self.conv2(x_3d)     # shape = [*, width, frame, grid, grid]
			x_3d = x_3d.permute(0, 2, 1, 3, 4)      # shape = [*, frame, width, grid, grid]
			x = x_3d.reshape(-1, x_3d.shape[-3], x_3d.shape[-2], x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid]
		else:
			x = self.conv1(x)  # shape = [*, width, grid, grid]

		if video_shape is not None:
			b, pair, bs, ts, channel, h, w = video_shape



		x = x.view(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
		x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

		x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]  
		x = x + self.positional_embedding.to(x.dtype) # spatial attention


  

		### yb: time embedding #########
		if self.transformer.space_time:
			x = rearrange(x, '(b p t l) n m -> (b p n l) t m', b=b, p=pair, t=bs, l=ts)
			x = x + self.positional_embedding_tmp.to(x.dtype) 
			x = rearrange(x, '(b p n l) t m -> (b p t l) n m', b=b, p=pair, t=bs, l=ts)
		#### end ###

		## yb: audio embed raw audio ###
		# set_trace()
		# audio = audio.unsqueeze(1)
		# audio = audio.transpose(2, 3)
		
		# audio = self.conv_audio(audio)
		
		# audio = audio.flatten(2).transpose(1, 2)


		### 3 AST ###
		# audio = rearrange(audio, 'b t p d -> (b t) 1 p d')

		# audio = self.conv_audio(audio)
		# _,_, h_a, w_a = audio.shape
		# audio = audio.flatten(2).transpose(1, 2)
		# # audio = audio + self.positional_audio.to(audio.dtype)
		# audio = torch.cat(
		#     [self.audio_cls.to(audio.dtype) + torch.zeros(audio.shape[0], 1, audio.shape[-1], dtype=audio.dtype, device=audio.device),
		#     self.audio_dist.to(audio.dtype) + torch.zeros(audio.shape[0], 1, audio.shape[-1], dtype=audio.dtype, device=audio.device), 
		#     audio], dim=1)  # shape = [*, grid ** 2 + 1, width]  
		# audio = audio.permute(1, 0, 2)
		

		###############################
		# end audio ###
		
		
		x = self.ln_pre(x)

		## audio pre-extract ##

		

		# audio = rearrange(audio, 'b t p d -> t (b p) d') ### S3D
		##


		x = x.permute(1, 0, 2)  # NLD -> LND
		# ori: (50, 128, 768) -> (7*7+1, 16(bs)*8(frames), 768)


		if self.transformer.audio_vis:

			if self.opt.yb_time_cross_audio:
				audio = repeat(audio, 'b t p d -> b c t p d', c=self.opt.max_audio_frames)
				audio = audio.squeeze(3) # bs x repeat x audio time x dim

			audio = rearrange(audio, 'b t p d -> p (b t) d')
			# audio_ori = rearrange(audio[2:], '(w h) t d -> w h t d', w=12, h=101)
			# audio_ori = rearrange(audio[2:], '(w h) t d -> w h t d', w=w_a, h=h_a)

			# perm = torch.randperm(audio_ori.size(1))
			# # idx = perm[:self.opt.audio_cluster]
			# # rand_rand = torch.randint(1, self.opt.audio_cluster,(1,1)).item()
			# idx = torch.randint(0,audio_ori.size(1)-self.opt.audio_cluster, (1,1))

			# smaple_audio = audio_ori[:, idx:idx+self.opt.audio_cluster]

			# smaple_audio = rearrange(smaple_audio, 'w h t d -> (w h) t d')
			# tmp_audio = torch.cat((audio[:2], smaple_audio), dim=0)
			# tmp_audio = torch.cat((audio[:1], smaple_audio), dim=0)

			# x,_ = self.transformer([x, audio[:2]], video_frame=video_frame, video_shape=video_shape)

			
			x, x_a = self.transformer([x, self.mlp_audio(audio)], video_frame=video_frame, video_shape=video_shape)
		else:
			# x,_ = self.transformer([x, audio[:2]], video_frame=video_frame, video_shape=video_shape)
			x = self.transformer(x, video_frame=video_frame, video_shape=video_shape)
			# x,_ = self.transformer([x, audio], video_frame=video_frame, video_shape=video_shape)
			x_a = None  # keep the (x, x_a) return at the end of forward() well-defined on this path

		
		# x = self.transformer(torch.cat((x,self.mlp_audio(audio)),dim=0),  video_frame=video_frame, video_shape=video_shape)
		
		
		x = x.permute(1, 0, 2)  # LND -> NLD


		# print(self.AST.v.blocks[0].attn.qkv.weight)
		# set_trace()
		# # 0.0138/0.0295
		# x_a = self.AST(rearrange(audio.float(), 'b p t f dim ->  (b p t) f dim', b=b, p=pair, t=self.opt.max_audio_frames))
		# x_a = rearrange(x_a, '(b p t) dim ->  b p t dim', b=b, p=pair, t=self.opt.max_audio_frames) #.mean(dim=-2)

		# print(self.AST.v.blocks[0].attn.qkv.weight)

		##### yb s-t-att #
		# x = x.view(int(x.size(0)/video_frame), video_frame, x.size(1),-1)
		# x = x[:,:,0,:]

		# att_weight = self.mlp_att_yb(x)
		# att_weight = F.softmax(att_weight, dim=1)
		# x = torch.bmm(att_weight.permute(0,2,1), x)
		### 

		# Move the three lines below to `encode_image` for entire hidden sequence
		# x = self.ln_post(x[:, 0, :])
		# if self.proj is not None:
		#     x = x @ self.proj

		return (x, x_a)

class CLIP(nn.Module):
	def __init__(self,
				 embed_dim: int,
				 # vision
				 image_resolution: int,
				 vision_layers: Union[Tuple[int, int, int, int], int],
				 vision_width: int,
				 vision_patch_size: int,
				 # text
				 context_length: int,
				 vocab_size: int,
				 transformer_width: int,
				 transformer_heads: int,
				 transformer_layers: int,
				 # vision linear of patch
				 linear_patch: str = '2d',
				# config
				 opt = None,
				 ):
		super().__init__()

		self.context_length = context_length
		self.opt = opt
		if isinstance(vision_layers, (tuple, list)):
			vision_heads = vision_width * 32 // 64
			self.visual = ModifiedResNet(
				layers=vision_layers,
				output_dim=embed_dim,
				heads=vision_heads,
				input_resolution=image_resolution,
				width=vision_width
			)
		else:
			vision_heads = vision_width // 64
			self.visual = VisualTransformer(
				input_resolution=image_resolution,
				patch_size=vision_patch_size,
				width=vision_width,
				layers=vision_layers,
				heads=vision_heads,
				output_dim=embed_dim,
				linear_patch=linear_patch,
				opt=self.opt
			)

		self.transformer = Transformer(
			width=transformer_width,
			layers=transformer_layers,
			heads=transformer_heads,
			attn_mask=self.build_attention_mask
		)

		self.vocab_size = vocab_size
		self.token_embedding = nn.Embedding(vocab_size, transformer_width)
		self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
		# self.positional_embedding = nn.Parameter(torch.zeros(self.context_length, transformer_width))
		self.ln_final = LayerNorm(transformer_width)

		self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
		self.logit_scale = nn.Parameter(torch.ones([]))

		self.initialize_parameters()

	def initialize_parameters(self):
		nn.init.normal_(self.token_embedding.weight, std=0.02)
		nn.init.normal_(self.positional_embedding, std=0.01)

		if isinstance(self.visual, ModifiedResNet):
			if self.visual.attnpool is not None:
				std = self.visual.attnpool.c_proj.in_features ** -0.5
				nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
				nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
				nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
				nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

			for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
				for name, param in resnet_block.named_parameters():
					if name.endswith("bn3.weight"):
						nn.init.zeros_(param)

		proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
		attn_std = self.transformer.width ** -0.5
		fc_std = (2 * self.transformer.width) ** -0.5
		for block in self.transformer.resblocks:
			nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
			nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
			nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
			nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

		if self.text_projection is not None:
			nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

	@staticmethod
	def get_config(pretrained_clip_name="ViT-B/32"):
		model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt")
		if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME:
			model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name])

		if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path):
			pass
		else:
			if pretrained_clip_name in _MODELS:
				model_path = _download(_MODELS[pretrained_clip_name])
			elif os.path.isfile(pretrained_clip_name):
				model_path = pretrained_clip_name
			else:
				raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}")

		try:
			# loading JIT archive
			model = torch.jit.load(model_path, map_location="cpu").eval()
			state_dict = model.state_dict()
		except RuntimeError:
			state_dict = torch.load(model_path, map_location="cpu")

		return state_dict
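
	# Illustrative call (a sketch, not prescribed by the upstream repo): `CLIP.get_config("ViT-B/32")`
	# returns the raw checkpoint state dict, downloading the weights if no local copy sits next to
	# this file; the result is typically handed to `build_model` below.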

	def build_attention_mask(self, context_length):
		# lazily create causal attention mask, with full attention between the vision tokens
		# pytorch uses additive attention mask; fill with -inf
		mask = torch.zeros(context_length, context_length)
		mask.fill_(float("-inf"))
		mask.triu_(1)  # keep -inf strictly above the diagonal; zero out the diagonal and below
		return mask
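
		# e.g. context_length = 3 yields the additive mask
		#   [[0., -inf, -inf],
		#    [0.,   0., -inf],
		#    [0.,   0.,   0.]]
		# so each text token attends only to itself and earlier positions.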

	@property
	def dtype(self):
		return self.visual.conv1.weight.dtype

	# add audio yb
	def encode_image(self, image, audio, return_hidden=False, video_frame=-1, video_shape=None, is_train=False):
		if video_shape is not None:
			b, pair, bs, ts, channel, h, w = video_shape
		
		hidden = self.visual(image.type(self.dtype), audio.type(self.dtype), video_frame=video_frame, video_shape=video_shape, is_train=is_train)
		# hidden = self.visual.ln_post(hidden) @ self.visual.proj
		# x = hidden[:, 0, :]

		# In this adapted code `hidden` is the (visual, audio) token pair from VisualTransformer and
		# the pooled/projected feature is computed by the caller, so the hidden pair is returned
		# directly (this is also what the attention-visualization path expects).
		return hidden
		# return x  # yb: supposed to be x

	def encode_text(self, text, return_hidden=False):
		x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

		
		pos_emb = self.positional_embedding[:x.size(1), :].type(self.dtype)

		if pos_emb.size(0) != x.size(1):
			pos_emb = F.interpolate(pos_emb.unsqueeze(0).permute(0, 2, 1), x.size(1), mode='linear').squeeze(0).permute(1, 0)

		x = x + pos_emb
		x = x.permute(1, 0, 2)  # NLD -> LND
		x = self.transformer(x)
		x = x.permute(1, 0, 2)  # LND -> NLD

		hidden = self.ln_final(x).type(self.dtype) @ self.text_projection

		# x.shape = [batch_size, n_ctx, transformer.width]
		# take features from the eot embedding (eot_token is the highest number in each sequence)
		x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]

		if return_hidden:
			return x, hidden

		return x

	def forward(self, image, text):
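		# Note: this forward keeps the upstream CLIP image-text contract and does not pass the audio
		# tensor that the adapted encode_image() above requires, so callers in this repo presumably
		# invoke encode_image / encode_text directly.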
		image_features = self.encode_image(image)
		text_features = self.encode_text(text)

		# normalized features
		image_features = image_features / image_features.norm(dim=-1, keepdim=True)
		text_features = text_features / text_features.norm(dim=-1, keepdim=True)

		# cosine similarity as logits
		logit_scale = self.logit_scale.exp()
		logits_per_image = logit_scale * image_features @ text_features.t()
		logits_per_text = logit_scale * text_features @ image_features.t()

		# shape = [global_batch_size, global_batch_size]
		return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
	"""Convert applicable model parameters to fp16"""

	def _convert_weights_to_fp16(l):
		if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
			l.weight.data = l.weight.data.half()
			if l.bias is not None:
				l.bias.data = l.bias.data.half()

		if isinstance(l, nn.MultiheadAttention):
			for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
				tensor = getattr(l, attr)
				if tensor is not None:
					tensor.data = tensor.data.half()

		for name in ["text_projection", "proj"]:
			if hasattr(l, name):
				attr = getattr(l, name)
				if attr is not None:
					# add here
					if isinstance(attr, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
						attr.weight.data = attr.weight.data.half()
						if attr.bias is not None:
							attr.bias.data = attr.bias.data.half()
					## end ##
					else:
						attr.data = attr.data.half()
					
	

	model.apply(_convert_weights_to_fp16)
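
# Typical use (a sketch of intent, mirroring the upstream CLIP repo): call `convert_weights(model)`
# after construction so Linear/Conv and attention projection weights are cast to fp16, matching the
# released checkpoints; the LayerNorm subclass above still computes in fp32 regardless.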


def build_model(state_dict: dict):
	vit = "visual.proj" in state_dict

	if vit:
		vision_width = state_dict["visual.conv1.weight"].shape[0]
		vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
		vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
		grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
		image_resolution = vision_patch_size * grid_size
	else:
		counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
		vision_layers = tuple(counts)
		vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
		output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
		vision_patch_size = None
		assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
		image_resolution = output_width * 32

	embed_dim = state_dict["text_projection"].shape[1]
	context_length = state_dict["positional_embedding"].shape[0]
	vocab_size = state_dict["token_embedding.weight"].shape[0]
	transformer_width = state_dict["ln_final.weight"].shape[0]
	transformer_heads = transformer_width // 64
	transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

	model = CLIP(
		embed_dim,
		image_resolution, vision_layers, vision_width, vision_patch_size,
		context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
	)

	for key in ["input_resolution", "context_length", "vocab_size"]:
		if key in state_dict:
			del state_dict[key]

	convert_weights(model)
	model.load_state_dict(state_dict)
	return model.eval()
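
# Illustrative end-to-end sketch (hedged): the two helpers can chain as
#
#   state_dict = CLIP.get_config(pretrained_clip_name="ViT-B/32")
#   model = build_model(state_dict)
#
# Note that `build_model` constructs CLIP without an `opt` config while the audio-visual
# VisualTransformer reads fields such as `opt.max_frames`, so the ViT path in this adapted file is
# assumed to be built elsewhere with an explicit `opt`.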
